from fileinput import filename
from genericpath import isdir, isfile
from operator import truediv
from turtle import update
from scipy.io import savemat, loadmat
import scipy.sparse as sparse
from sklearn import metrics

import numpy as np

from attack_utils import train_model, test_model
from data_utils import get_dataset, get_data_params, get_class_map, get_centroids, get_centroid_vec

import matplotlib.pyplot as plt
import os
import argparse

import warnings
import time
import pickle
from sklearn.decomposition import PCA
import cvxpy as cvx
import matplotlib.pyplot as plt

# Suppress every warning for the rest of the process.
warnings.filterwarnings("ignore")

# Command-line interface of the analysis script.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='mnist_17',help="options: mnist_49, mnist_69, mnist_17, dogfish,cifar10_trial")
parser.add_argument('--model_type',default='svm',help='victim model type: SVM or rlogistic regression')
parser.add_argument('--weight_decay',default=0.09, type=float, help='weight decay for regularizers')
parser.add_argument('--rand_seed',default=1234, type=int, help='seed for random number generator')
parser.add_argument('--epsilon',default=0.03, type=float, help='poisoning ratio')

parser.add_argument('--check_transfer',action="store_true", help='check the transferability of different models')
parser.add_argument('--ldc_compare',default=1.0,type=float, help='compare to ldc by using larger poisoning fraction')
# type=float: without it a value given on the command line stays a str, which
# breaks any consumer that hands args.percentile straight to np.percentile.
parser.add_argument('--percentile', default=90, type=float,
                    help='percentile (0-100) used to set the sphere/slab radii')
parser.add_argument('--use_slab',action="store_true", help='use oracale slab defense')
parser.add_argument('--use_sphere',action="store_true", help='use sphere defese')
parser.add_argument('--filter_points',action="store_true", help='filter incorrectly classisified points when computing metrics')
parser.add_argument('--use_U',action="store_true", help='use the tight upper bound to see if it gives good projection vector')

# Parsed at import time; the module-level names below are read by functions in this file.
args = parser.parse_args()
print(args)
use_sphere = args.use_sphere
use_slab = args.use_slab

def compute_dists_under_Q(
    X, Y,
    Q,
    subtract_from_l2=False, # If true, returns sqrt(||x - mu||^2 - ||Q(x - mu)||^2)
    centroids=None,
    class_map=None,
    norm=2):
    """
    Computes ||Q(x - mu)|| in the corresponding norm, where mu is the centroid
    of the class that x belongs to.

    Returns a vector of length num_examples (X.shape[0]).
    If centroids is not specified, calculate it from the data (per-class mean).
    centroids and class_map must be given together or not at all.
    If Q is None it is treated as the identity.
    If Q has dimension 3, then each class gets its own Q (indexed through
    class_map, which is then required).
    If subtract_from_l2 is true, the returned value is instead
    sqrt(||x - mu||^2 - ||Q(x - mu)||^2); Q must not be None in that case.
    Negative differences are clipped to zero before the square root so that
    floating-point rounding cannot yield NaN (note this also maps a genuinely
    negative difference to 0).

    Raises ValueError if norm is not 1 or 2.
    """
    if (centroids is not None) or (class_map is not None):
        assert (centroids is not None) and (class_map is not None)
    if subtract_from_l2:
        assert Q is not None
    if Q is not None and len(Q.shape) == 3:
        assert class_map is not None
        assert Q.shape[0] == len(class_map)

    if norm == 1:
        metric = 'manhattan'
    elif norm == 2:
        metric = 'euclidean'
    else:
        raise ValueError('norm must be 1 or 2')

    Q_dists = np.zeros(X.shape[0])

    for y in set(Y):
        in_class = Y == y
        X_class = X[in_class, :]

        if centroids is not None:
            mu = centroids[class_map[y], :]
        else:
            mu = np.mean(X_class, axis=0)
        mu = mu.reshape(1, -1)

        if Q is None:   # assume Q = identity
            Q_dists[in_class] = metrics.pairwise.pairwise_distances(
                X_class,
                mu,
                metric=metric).reshape(-1)
            continue

        if len(Q.shape) == 3:
            current_Q = Q[class_map[y], ...]
        else:
            current_Q = Q

        # Apply Q to the points; the sparse branch keeps X on the left of the product.
        if sparse.issparse(X):
            XQ = X_class.dot(current_Q.T)
        else:
            XQ = current_Q.dot(X_class.T).T
        muQ = current_Q.dot(mu.T).T

        projected = metrics.pairwise.pairwise_distances(
            XQ,
            muQ,
            metric=metric).reshape(-1)

        if subtract_from_l2:
            full = metrics.pairwise.pairwise_distances(
                X_class,
                mu,
                metric=metric).reshape(-1)
            # Clip at zero: rounding can push the difference slightly below 0.
            residual_sq = np.maximum(np.square(full) - np.square(projected), 0)
            Q_dists[in_class] = np.sqrt(residual_sq)
        else:
            Q_dists[in_class] = projected

    return Q_dists

# Can speed this up if necessary
def get_removed_data(X, Y, percentile):
    """
    Drop the points flagged by the sphere or the slab outlier rule.

    For every label in Y, a point is flagged when its distance to the class
    centroid (sphere) or its absolute projection offset along the
    centroid-to-centroid vector (slab) exceeds the given percentile of that
    quantity within its own class. Both rules are always applied; a point
    flagged by either one is removed.

    Args:
        X: data matrix, one example per row.
        Y: label vector aligned with the rows of X.
        percentile: percentile (0-100) that defines the per-class radii.

    Returns:
        (X_filtered, Y_filtered): the rows / labels that were not flagged.
    """
    class_map = get_class_map()
    centroids = get_centroids(X, Y, class_map)
    labels = set(Y)

    # Sphere rule: L2 distance of every point to its own class centroid.
    sphere_dists = compute_dists_under_Q(
        X, Y,
        Q=None,
        centroids=centroids,
        class_map=class_map,
        norm=2)
    sphere_outliers = np.zeros(X.shape[0], dtype=bool)
    for y in labels:
        in_class = Y == y
        sphere_radius = np.percentile(sphere_dists[in_class], percentile)
        sphere_outliers[in_class] = sphere_dists[in_class] > sphere_radius

    # Slab rule: offset from the class centroid along the vector between centroids.
    centroid_vec = get_centroid_vec(centroids)
    slab_outliers = np.zeros(X.shape[0], dtype=bool)
    for y in labels:
        in_class = Y == y
        slab_dists = np.abs(
            X[in_class, :].dot(centroid_vec.T)
            - centroids[class_map[y], :].dot(centroid_vec.T)).reshape(-1)
        slab_radius = np.percentile(slab_dists, percentile)
        slab_outliers[in_class] = slab_dists > slab_radius

    keep = ~np.logical_or(slab_outliers, sphere_outliers)
    return X[keep], Y[keep]

def remove_points_old(X,Y,use_slab=False,use_sphere=False,defense_pars = None,percentile=None):
    """
    Remove points that fall outside the slab and/or sphere defense.

    For each class in (1, -1), only the points carrying that label are scored
    against that class's centroid. A point is removed when its score exceeds
    the class radius from defense_pars AND lies above the given percentile of
    the scores within its class. With both defenses enabled, a point flagged
    by either one is removed.

    Args:
        X: data matrix (numpy array), one example per row.
        Y: numpy label vector with values in {1, -1}, aligned with X.
        use_slab: apply the slab rule (|<x - centroid, centroid_vec>|).
        use_sphere: apply the sphere rule (||x - centroid||^2).
        defense_pars: (class_map, centroids, centroid_vec, sphere_radii,
            slab_radii); required when a defense is enabled.
        percentile: percentile (0-100) used as the second threshold; defaults
            to the module-level args.percentile.

    Returns:
        (X_filtered, Y_filtered) in the original row order. When no defense is
        enabled the inputs are returned unchanged.
    """
    if not (use_slab or use_sphere):
        return X, Y

    assert defense_pars is not None
    class_map, centroids, centroid_vec, sphere_radii, slab_radii = defense_pars
    if percentile is None:
        percentile = float(args.percentile)
    centroid_vec = np.asarray(centroid_vec).reshape(-1)

    outliers = np.zeros(X.shape[0], dtype=bool)
    for y in (1, -1):
        in_class = Y == y
        if not np.any(in_class):
            continue
        class_idx = class_map[y]
        offsets = X[in_class] - np.asarray(centroids[class_idx]).reshape(-1)
        class_outliers = np.zeros(offsets.shape[0], dtype=bool)

        if use_slab:
            slab_scores = np.abs(offsets.dot(centroid_vec))
            class_outliers |= np.logical_and(
                slab_scores > slab_radii[class_idx],
                slab_scores > np.percentile(slab_scores, percentile))
        if use_sphere:
            # squared L2 distance, compared against the squared radius
            sphere_scores = np.sum(np.square(offsets), axis=1)
            class_outliers |= np.logical_and(
                sphere_scores > sphere_radii[class_idx] ** 2,
                sphere_scores > np.percentile(sphere_scores, percentile))

        outliers[in_class] = class_outliers

    keep = ~outliers
    return X[keep], Y[keep]

def main(args):
    """
    Run the separability analysis for the dataset selected in args.

    For each epoch (only epoch 0 unless the dataset is 'cifar10_trial') this
    loads the data, optionally filters the training set with the sphere/slab
    rules, trains and evaluates a clean model, and calls analyze_distirbution
    to obtain the two statistics that are printed as "Sep/SD" and "Sep/Size".

    Args:
        args: parsed command-line namespace. args.weight_decay is overwritten
            with 0.01 when the dataset is 'imdb'.

    Returns:
        (seps, constraints): lists with one entry per processed epoch.
    """
    if args.dataset == 'imdb':
        args.weight_decay = 0.01
    weighted_center = False

    percentile = int(np.round(float(args.percentile)))
    np.random.seed(args.rand_seed)
    seps = []
    constraints = []
    restrict_imdb = False # only used to see if it can speed up computation, but not really

    if args.dataset != 'cifar10_trial':
        epochs = [0]
    else:
        epochs = [-1, 0, 50, 90, 100, 120]

    for epoch in epochs:
        X_train,Y_train,X_test,Y_test,x_lims = get_dataset(args,epoch)
        if args.dataset == 'imdb' and restrict_imdb:
            x_min,x_max = x_lims
            if sparse.issparse(x_max):
                x_min = x_min.toarray().reshape(-1)
                x_max = x_max.toarray().reshape(-1)
            # following the setup in Koh et al., (2022)
            x_max[x_max < 1] = 1
            x_max[x_max > 50] = 50
            # toarray() creates new arrays, so the clipped limits must be stored back
            x_lims = [x_min, x_max]
        if args.ldc_compare != 1.0:
            x_min, x_max = x_lims
            x_lims = [args.ldc_compare * x_min, args.ldc_compare * x_max]

        if sparse.issparse(X_train):
            X_train = np.asarray(sparse.csr_matrix.todense(X_train))
            X_test = np.asarray(sparse.csr_matrix.todense(X_test))

        # get the class data so as to better evaluate impact of defenses
        class_map, centroids, centroid_vec, sphere_radii, slab_radii = get_data_params(
            X_train,
            Y_train,
            percentile=percentile)
        defense_pars = [class_map, centroids, centroid_vec, sphere_radii, slab_radii]

        if args.use_slab or args.use_sphere:
            X_train_filter, Y_train_filter = get_removed_data(X_train, Y_train, percentile)
        else:
            X_train_filter, Y_train_filter = X_train, Y_train

        print("--- Train/Test Data Size --- ")
        print(X_train.shape,Y_train.shape,X_test.shape,Y_test.shape)
        if args.dataset != 'cifar10_trial':
            print("Data Search Space: Min {}, Max {}".format(x_lims[0],x_lims[1]))

        print("--- Performance of Clean Models --- ")
        clean_model = train_model(X_train,Y_train,args)
        # Called with verbose=True for its printed report; the returned metrics are not used here.
        test_model(X_train,Y_train,X_train,Y_train,X_test,Y_test,clean_model,args,verbose=True)
        clean_data = [X_train,Y_train,X_test,Y_test,X_train_filter, Y_train_filter]
        box_constraints = x_lims

        # print out the important statistics of the dataset vulnerabilities
        sep, constraint = analyze_distirbution(
            args,clean_data,box_constraints,weighted_center=weighted_center,
            use_slab=args.use_slab,use_sphere=args.use_sphere,
            defense_pars=defense_pars,filter_points=args.filter_points)
        seps.append(sep)
        constraints.append(constraint)

        print("Sep/SD:",sep)
        print("Sep/Size:",constraint)

    return seps, constraints

def visualize_data(x_axis_ls,x_centers,y_axis_ls,legends,colors,markers,x_lims,fname, _bias=None):
    """
    Scatter-plot 1-D projected data per class and save the figure to fname.

    Each entry of x_axis_ls / y_axis_ls is one data set (values and +1/-1
    labels). Positive points are drawn at height +scale, negative points at
    -scale, together with the two class centers from x_centers
    (positive first, negative second). x_lims = (x_min, x_max) sets the x
    range and the "box" width used in the reported ratios. _bias is accepted
    but not used.

    Returns:
        (center gap / weighted std, center gap / (x_max - x_min)), where the
        std is the one computed for the last data set in x_axis_ls.
    """
    plt.clf()
    x_min,x_max = x_lims
    x_pos_center, x_neg_center = x_centers[0], x_centers[1]
    base_scale = 1

    for idx, X in enumerate(x_axis_ls):
        base_scale *= idx + 1
        Y = y_axis_ls[idx]
        pos_color, neg_color = colors[idx]

        pos_X = X[Y==1]
        pos_var = np.var(pos_X)
        neg_X = X[Y==-1]
        neg_var = np.var(neg_X)

        # class-frequency weighted variance
        neg_frac = np.sum(Y==-1)/len(Y)
        pos_frac = np.sum(Y==1)/len(Y)
        overall_var = neg_frac * neg_var + pos_frac * pos_var
        overall_std = np.sqrt(overall_var)

        # positive class and its center on the upper row
        plt.scatter(
            pos_X, base_scale * np.ones(len(pos_X)),
            color=pos_color, marker=markers[idx],
            label='{} Positive: {}, Var: {:.5f}'.format(legends[idx],len(pos_X),pos_var))
        plt.scatter(
            np.array([x_pos_center]), base_scale * np.ones(1),
            color='m', marker="d", label='Positive Center')

        # negative class and its center on the lower row
        plt.scatter(
            neg_X, -1 * base_scale * np.ones(len(neg_X)),
            color=neg_color, marker=markers[idx],
            label='{} Negative: {}, Var: {:.5f}'.format(legends[idx],len(neg_X),neg_var))
        plt.scatter(
            np.array([x_neg_center]), -1 * base_scale * np.ones(1),
            color='g', marker="d", label='Negative Center')

        center_gap = x_pos_center[0][0]-x_neg_center[0][0]
        box_width = x_max-x_min
        plt.title("Dist Center: {:.3f}, Dist Box: {:.3f}, Ratio: {:.3f}".format(
            center_gap, box_width, center_gap/box_width))
        print("distance between centers {:.3f}, distance between box {:.3f}, ratio: {:.3f}, var:{:.3f}, ratio/sigma: {:.3f}".format(
            center_gap, box_width, center_gap/box_width, overall_std, center_gap/overall_std))

        plt.legend()
        plt.xlim(x_min-1,x_max+1)
    plt.savefig(fname)

    final_gap = x_pos_center[0][0]-x_neg_center[0][0]
    return final_gap/overall_std, final_gap/(x_max-x_min)

def visualize_trend(x_axis_ls,y_axis_ls,legends,linestyles,colors,fname,feasible_biases,sign_feasible_biases):
    """
    Plot loss curves against the bias on a log y-axis and save them to fname.

    Two vertical lines mark feasible_biases[0] (black) and feasible_biases[1]
    (red), labelled with the matching entries of sign_feasible_biases. Curve i
    is drawn from x_axis_ls[i] / y_axis_ls[i] with legends[i], linestyles[i]
    and colors[i].

    The axis decoration and the save happen once, after all curves are drawn,
    so the figure is also written when y_axis_ls is empty.
    """
    plt.clf()
    plt.axvline(x=feasible_biases[0],color='k',label='sign: {}'.format(sign_feasible_biases[0]))
    plt.axvline(x=feasible_biases[1],color='r',label='sign: {}'.format(sign_feasible_biases[1]))
    for i, ys in enumerate(y_axis_ls):
        plt.plot(x_axis_ls[i],ys,label=legends[i],linestyle=linestyles[i],color=colors[i])

    plt.xlabel('Bias from Low to High')
    plt.ylabel('Losses')
    plt.legend()
    plt.yscale('log')
    plt.savefig(fname)


def cvx_dot(a,b):
    """Return the inner product of a and b as a cvxpy expression."""
    # Replaces the older spelling cvx.sum_entries(cvx.mul_elemwise(a, b)).
    elementwise = cvx.multiply(a, b)
    return cvx.sum(elementwise)

def get_proj_box_constraint(box_constraints,w,verbose=False,use_slab=False,use_sphere=False,defense_pars=None,\
    dataset='mnist_17',discrete=False):
    """
    Find the extreme values of -y * <w, x> over the feasible input set.

    The feasible set is the box x_min <= x <= x_max, optionally intersected
    with the sphere and/or slab defense of the class y being considered. For
    each y in (1, -1) the objective -y * <w, x> is maximized and minimized;
    the overall largest maximum and smallest minimum are returned.

    Args:
        box_constraints: (x_min, x_max) bounds on x.
        w: 1-D projection vector.
        verbose: forwarded to the cvxpy solver.
        use_slab / use_sphere: add the respective per-class defense constraint.
        defense_pars: (class_map, centroids, centroid_vec, sphere_radii,
            slab_radii); required when a defense is enabled.
        dataset: 'adult' switches to a mixed variable whose first 4 entries
            are continuous and the rest boolean (requires discrete=True).
        discrete: round the optimal points (all entries, or entries from index
            4 on for 'adult') and recompute the two values from them.

    Returns:
        [proj_x_min, proj_x_max]

    Note: the solver is hard-coded to GUROBI.
    """
    if use_slab or use_sphere:
        assert defense_pars is not None
        class_map, centroids, centroid_vec, sphere_radii, slab_radii = defense_pars
    classes = [1,-1]
    x_min, x_max = box_constraints
    if dataset == 'adult':
        assert discrete
        cvx_x_real = cvx.Variable(4)
        cvx_x_binary = cvx.Variable(w.shape[0]-4, boolean=True)
        cvx_x = cvx.hstack([cvx_x_real,cvx_x_binary])
    else:
        cvx_x = cvx.Variable(w.shape[0])
    cvx_w = cvx.Parameter(w.shape[0])
    cvx_w.value = w
    # Constraints shared by both classes.
    box_only = [cvx_x >= x_min, cvx_x <= x_max]
    proj_x_min, proj_x_max = 1e10, -1e10

    for y in classes:
        cvx_loss = -y * cvx_dot(cvx_w,cvx_x)
        # Start from a fresh copy so that one class's defense constraints
        # do not leak into the problem solved for the other class.
        cvx_constraints = list(box_only)
        if use_slab or use_sphere:
            cvx_centroid = cvx.Parameter(w.shape[0])
            cvx_centroid_vec = cvx.Parameter(w.shape[0])
            cvx_slab_radius = cvx.Parameter(1)
            cvx_sphere_radius = cvx.Parameter(1)

            # assign the variable values
            cvx_centroid.value = centroids[class_map[y]].reshape(-1)
            cvx_centroid_vec.value = centroid_vec.reshape(-1)
            cvx_slab_radius.value = [slab_radii[class_map[y]]]
            cvx_sphere_radius.value = [sphere_radii[class_map[y]]]

        if use_sphere:
            cvx_constraints.append(cvx.pnorm(cvx_x - cvx_centroid, 2) ** 2 <= cvx_sphere_radius ** 2)
        if use_slab:
            cvx_constraints.append(cvx_dot(cvx_centroid_vec, cvx_x - cvx_centroid) <= cvx_slab_radius)
            cvx_constraints.append(-cvx_dot(cvx_centroid_vec, cvx_x - cvx_centroid) <= cvx_slab_radius)

        # solve the maximization to obtain upper bound for the given projection
        cvx_objective = cvx.Maximize(cvx_loss)
        cvx_prob = cvx.Problem(cvx_objective,cvx_constraints)
        proj_x_max_tmp = cvx_prob.solve(verbose=verbose, solver=cvx.GUROBI)
        if proj_x_max_tmp > proj_x_max:
            proj_x_max = proj_x_max_tmp
            best_x_max = np.array(cvx_x.value)
            best_y_max = y
        # solve the minimization to obtain the lower bound for the given projection
        cvx_objective = cvx.Minimize(cvx_loss)
        cvx_prob = cvx.Problem(cvx_objective,cvx_constraints)
        proj_x_min_tmp = cvx_prob.solve(verbose=verbose, solver=cvx.GUROBI)
        if proj_x_min_tmp < proj_x_min:
            proj_x_min = proj_x_min_tmp
            best_x_min = np.array(cvx_x.value)
            best_y_min = y

    if discrete:
        if dataset != 'adult':
            best_x_min = np.rint(best_x_min)
            best_x_max = np.rint(best_x_max)
        else:
            # only the entries from index 4 on are rounded for 'adult'
            best_x_min[4:] = np.rint(best_x_min[4:])
            best_x_max[4:] = np.rint(best_x_max[4:])

        # re-evaluate the objective at the rounded points
        proj_x_min = -best_y_min * np.dot(w,best_x_min)
        proj_x_max = -best_y_max * np.dot(w,best_x_max)

    return [proj_x_min,proj_x_max]

def train_best_bias(proj_x,proj_y,clean_data,args):
    """
    Fit a model on projected data and report its sign, scale and bias.

    For the 1-D case the model has the form a*x' + b with x' = w^T x, so a*w^T
    acts as the new weight vector and b stays the bias.

    Args:
        proj_x, proj_y: projected training inputs and labels used for fitting.
        clean_data: (proj_x_train, proj_y_train, proj_x_test, proj_y_test),
            used only to evaluate the fitted model.
        args: configuration forwarded to train_model / test_model.

    Returns:
        (sign of the fitted weight, absolute value of the fitted weight,
         fitted intercept, fitted model)
    """
    proj_x_train,proj_y_train,proj_x_test,proj_y_test = clean_data
    print(proj_x.shape,proj_x_train.shape,proj_x_test.shape)

    best_model = train_model(proj_x,proj_y,args)
    fitted_w = best_model.coef_.reshape(-1)
    fitted_bias = best_model.intercept_.reshape(-1)

    # evaluate on the clean train / test split; only the accuracies are reported
    evaluation = test_model(
        proj_x,proj_y,proj_x_train,proj_y_train,proj_x_test,proj_y_test,
        best_model,args,verbose=False)
    _, _, _train_loss, _test_loss, _, _, clean_train_acc, clean_test_acc = evaluation

    print("resultant w:",fitted_w, fitted_bias)
    print("performance: test {}, train {}".format(1-clean_test_acc,1-clean_train_acc))

    weight_sign = np.sign(fitted_w)
    weight_scale = np.abs(fitted_w)
    return weight_sign, weight_scale, fitted_bias, best_model

def compute_loss(margin, loss_type = 'svm'):
    """
    Evaluate a margin-based loss elementwise.

    Args:
        margin: scalar or numpy array of margins y * f(x).
        loss_type: 'svm' for the hinge loss max(0, 1 - margin), 'lr' for the
            logistic loss log(1 + exp(-margin)).

    Returns:
        The loss with the same shape as margin.

    Raises:
        NotImplementedError: for any other loss_type.
    """
    if loss_type == 'svm':
        return np.maximum(0,1-margin)
    if loss_type == 'lr':
        # logaddexp(0, -m) == log(1 + exp(-m)) but does not overflow for very negative m
        return np.logaddexp(0, -margin)
    raise NotImplementedError("unsupported loss_type: {!r}".format(loss_type))

def compute_max_loss(proj_x_lims, best_b, calib_w, classes):
    """
    Return the largest loss attainable at the ends of the projected range.

    In 1-D the projection direction and bias are fixed, so only the two
    extreme projected positions need to be checked for every label in
    classes. The loss is compute_loss with its default loss type, averaged
    over axis 0.

    Returns -1e10 when no candidate exceeds that floor.
    """
    _low_end, _high_end = proj_x_lims  # insists on exactly two endpoints
    candidate_losses = (
        np.mean(compute_loss(label * (calib_w * endpoint + best_b)), axis=0)
        for endpoint in proj_x_lims
        for label in classes
    )

    worst = -1e10
    for loss in candidate_losses:
        if worst < loss:
            worst = loss
    return worst

def analyze_distirbution(args,clean_data,box_constraints,weighted_center=False,\
                         use_slab=False,use_sphere=False,defense_pars=None,filter_points=False,use_all=False):
    """
    Project the data onto a linear model's weight vector and measure how well
    the two classes separate along it.

    Args:
        args: parsed command-line namespace (dataset, model_type, weight_decay,
            epsilon, use_U are read here).
        clean_data: (X_train, Y_train, X_test, Y_test, X_train_filter,
            Y_train_filter).
        box_constraints: (x_min, x_max) passed to get_proj_box_constraint.
        weighted_center: use margin-weighted class centers instead of means.
        use_slab / use_sphere / defense_pars: forwarded to
            get_proj_box_constraint; defense_pars is required when a defense
            is enabled.
        filter_points: keep only points the clean model classifies correctly.
        use_all: analyze filtered-train + test data instead of test data only.

    Returns:
        (sep, constraint) as returned by visualize_data: the projected center
        gap divided by the weighted std, and divided by the projected range.

    Side effects: trains a model, prints progress, and writes a figure under
    files/figures/<dataset>/<model_type>/<weight_decay>/.
    """
    if use_slab or use_sphere:
        print("Using Defense!")
        assert defense_pars is not None

    discrete = args.dataset in ['adult','enron','imdb']
    X_train,Y_train,X_test,Y_test,X_train_filter, Y_train_filter = clean_data

    if use_all:
        X_total = np.concatenate((X_train_filter,X_test),axis=0)
        Y_total = np.concatenate((Y_train_filter,Y_test),axis=0)
    else:
        X_total = np.copy(X_test)
        Y_total = np.copy(Y_test)

    print("Pre filter shape:",X_total.shape,Y_total.shape)
    clean_model = train_model(X_train,Y_train,args)
    if filter_points:
        # drop the points the clean model gets wrong
        corr_ids = clean_model.predict(X_total) == Y_total
        X_total = X_total[corr_ids]
        Y_total = Y_total[corr_ids]
    print("After filter shape:",X_total.shape,Y_total.shape)

    if not args.use_U:
        _w = clean_model.coef_.reshape(-1)
        _bias = clean_model.intercept_.reshape(-1)
    else:
        fdir = '../indiscriminate_attack_clean/files/target_classifiers/{}/svm/0.09/min_U'.format(args.dataset)
        fname = '{}/iter-30000_lr-0.03_eps-{}.npz'.format(fdir,args.epsilon)
        # The archive is read lazily, so pull the arrays out before it is closed.
        with np.load(fname) as f:
            _w = f['best_theta']
            _bias = f['best_bias']
            U = f['U']
        print("Loaded model of upper bound {} at eps {}".format(U,args.epsilon))

    proj_x_total = np.matmul(X_total,_w).reshape(-1,1)
    proj_x_min, proj_x_max = get_proj_box_constraint(box_constraints,_w,\
            use_slab=use_slab,use_sphere=use_sphere,defense_pars=defense_pars,
            dataset = args.dataset, discrete=discrete)

    # find the centers of the dataset
    X_pos = X_total[Y_total == 1]
    X_neg = X_total[Y_total == -1]
    if weighted_center:
        # weight each sample by its share of the class's total margin
        Y_pos = Y_total[Y_total == 1]
        Y_neg = Y_total[Y_total == -1]

        pos_margins = Y_pos*(X_pos.dot(_w) + _bias)
        weighted_X_pos = X_pos * pos_margins[:,np.newaxis] / np.sum(pos_margins)
        X_pos_center = np.sum(weighted_X_pos,axis=0)

        neg_margins = Y_neg*(X_neg.dot(_w) + _bias)
        weighted_X_neg = X_neg * neg_margins[:,np.newaxis] / np.sum(neg_margins)
        X_neg_center = np.sum(weighted_X_neg,axis=0)
        center_type = 'margin_weighted'
    else:
        center_type = 'mean'
        X_pos_center = np.mean(X_pos,axis=0)
        X_neg_center = np.mean(X_neg,axis=0)
    proj_x_pos_center = np.matmul(X_pos_center,_w).reshape(-1,1)
    proj_x_neg_center = np.matmul(X_neg_center,_w).reshape(-1,1)
    proj_x_centers = [proj_x_pos_center,proj_x_neg_center]

    # visualize the figures
    fig_dir = 'files/figures/{}/{}/{}'.format(args.dataset,args.model_type,args.weight_decay)
    os.makedirs(fig_dir, exist_ok=True)
    fname = '{}/clean_boundary_{}.png'.format(fig_dir,center_type)
    proj_xs = [proj_x_total]
    ys = [Y_total]
    legends=['']
    colors = [('b','r'),('b','r')]
    markers = ['o','x']
    x_lims = [proj_x_min,proj_x_max]

    sep, constraint = visualize_data(proj_xs,proj_x_centers,ys,legends,colors,markers,x_lims,fname,_bias)
    return sep, constraint

def best_performance_on_proj_basis(args,clean_data,proj_basis,box_constraints):
    """
    For every projection vector, fit the best 1-D model and record its metrics.

    Args:
        args: configuration forwarded to the training / evaluation helpers.
        clean_data: (X_train, Y_train, X_test, Y_test).
        proj_basis: 2-D array, one projection vector per row.
        box_constraints: (x_min, x_max) passed to get_proj_box_constraint.

    Returns:
        (test_errors, test_losses, train_errors, train_losses, max_losses,
         calibrated_ws); the first five are arrays with one entry per basis
        vector, and calibrated_ws holds the rescaled weight vector with the
        fitted bias in the last column of each row.
    """
    X_train,Y_train,X_test,Y_test = clean_data
    num_basis = len(proj_basis)
    test_losses = []
    train_losses = []
    train_errors = []
    test_errors = []
    max_losses = []
    calibrated_ws = np.zeros(((num_basis),proj_basis.shape[1]+1))
    classes = [-1,1]
    for i in range(num_basis):
        w = proj_basis[i]
        proj_x_train = np.matmul(X_train,w).reshape(-1,1)
        proj_x_test = np.matmul(X_test,w).reshape(-1,1)
        proj_x_min, proj_x_max = get_proj_box_constraint(box_constraints,w)
        proj_x_lims = [proj_x_min,proj_x_max]
        # obtain the best model weight and bias for the projected vectors;
        # train_best_bias takes the evaluation split as one 4-item sequence
        proj_clean_data = [proj_x_train,Y_train,proj_x_test,Y_test]
        sign_, abs_, best_bias, best_model = train_best_bias(proj_x_train,Y_train,proj_clean_data,args)
        calibrated_ws[i,:-1] = sign_ * abs_ * w
        calibrated_ws[i,-1] = best_bias
        # record them
        _, _, clean_train_loss, clean_test_loss, _, _, clean_train_acc, clean_test_acc = test_model(
            proj_x_train,Y_train,proj_x_train,Y_train,proj_x_test,Y_test,best_model,args,verbose=False)
        test_losses.append(clean_test_loss)
        train_losses.append(clean_train_loss)
        test_errors.append(1-clean_test_acc)
        train_errors.append(1-clean_train_acc)
        # record max loss for each basis
        max_losses.append(compute_max_loss(proj_x_lims, best_bias, sign_ * abs_, classes))

    return (np.array(test_errors), np.array(test_losses), np.array(train_errors),
            np.array(train_losses), np.array(max_losses), calibrated_ws)

def performance_on_proj_basis(args,clean_data,proj_basis,box_constraints,eps,num_division=100,all_range_bias=True):
    """
    Sweep the bias of a linear model along each projection vector.

    A model is trained once on the clean data; for every basis vector its
    coef_ is overwritten with that vector and its intercept_ is swept over
    num_division evenly spaced values between the two bounds returned by
    get_proj_box_constraint. eps and all_range_bias are accepted but not used.

    Returns:
        (test_losses, train_losses, test_errors, train_errors, proj_x_train,
         proj_x_test, biases_dict): dicts keyed by basis index. The first four
        hold one array entry per swept bias, the next two the projected
        train / test data, the last the swept bias values.
    """
    X_train,Y_train,X_test,Y_test = clean_data
    tmp_model = train_model(X_train,Y_train,args)
    num_basis = len(proj_basis)

    test_losses = {k: [] for k in range(num_basis)}
    train_losses = {k: [] for k in range(num_basis)}
    test_errors = {k: [] for k in range(num_basis)}
    train_errors = {k: [] for k in range(num_basis)}
    proj_x_train, proj_x_test, biases_dict = {}, {}, {}

    for k in range(num_basis):
        direction = proj_basis[k]
        proj_x_train[k] = np.matmul(X_train,direction)
        proj_x_test[k] = np.matmul(X_test,direction)
        tmp_model.coef_ = np.array([direction])

        lower, upper = get_proj_box_constraint(box_constraints,direction)
        sweep = np.linspace(lower, upper, num=num_division)
        biases_dict[k] = sweep

        # vary the bias term to see the trend of induced model performance
        for bias in sweep:
            tmp_model.intercept_ = [bias]
            outcome = test_model(
                X_train,Y_train,X_train,Y_train,X_test,Y_test,
                tmp_model,args,verbose=False)
            _, _, clean_train_loss, clean_test_loss, _, _, clean_train_acc, clean_test_acc = outcome
            test_losses[k].append(clean_test_loss)
            train_losses[k].append(clean_train_loss)
            test_errors[k].append(1-clean_test_acc)
            train_errors[k].append(1-clean_train_acc)

    # freeze the per-basis lists into arrays
    for store in (test_losses, train_losses, test_errors, train_errors):
        for k in range(num_basis):
            store[k] = np.array(store[k])

    return test_losses,train_losses,test_errors,train_errors,proj_x_train,proj_x_test,biases_dict

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main(args)